{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/BERT.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT in TransformerLens\n",
    "This demo shows how to use BERT in TransformerLens for the Masked Language Modelling and Next Sentence Prediction tasks."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup\n",
    "(No need to read)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running as a Jupyter notebook - intended for development only!\n"
     ]
    }
   ],
   "source": [
    "# NBVAL_IGNORE_OUTPUT\n",
    "import os\n",
    "\n",
    "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
    "DEVELOPMENT_MODE = False\n",
    "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
    "try:\n",
    "    import google.colab\n",
    "\n",
    "    IN_COLAB = True\n",
    "    print(\"Running as a Colab notebook\")\n",
    "\n",
    "    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
    "    # # Install another version of node that makes PySvelte work way faster\n",
    "    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
    "    # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
    "except ImportError:\n",
    "    # Only a failed import means \"not in Colab\"; any other error should surface instead of being swallowed\n",
    "    IN_COLAB = False\n",
    "\n",
    "if not IN_GITHUB and not IN_COLAB:\n",
    "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
    "    from IPython import get_ipython\n",
    "\n",
    "    ipython = get_ipython()\n",
    "    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel\n",
    "    # run_line_magic(name, args) is used because ipython.magic(...) is deprecated\n",
    "    ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
    "    ipython.run_line_magic(\"autoreload\", \"2\")\n",
    "\n",
    "if IN_COLAB:\n",
    "    %pip install transformer_lens\n",
    "    %pip install circuitsvis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using renderer: colab\n"
     ]
    }
   ],
   "source": [
    "# Plotly needs a different renderer in Colab than in VSCode/Jupyter notebooks\n",
    "import plotly.io as pio\n",
    "\n",
    "# The Colab renderer is used in Colab and whenever DEVELOPMENT_MODE is off\n",
    "pio.renderers.default = (\n",
    "    \"colab\" if (IN_COLAB or not DEVELOPMENT_MODE) else \"notebook_connected\"\n",
    ")\n",
    "print(f\"Using renderer: {pio.renderers.default}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div id=\"circuits-vis-8c91db10-74f4\" style=\"margin: 15px 0;\"/>\n",
       "    <script crossorigin type=\"module\">\n",
       "    import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.2/dist/cdn/esm.js\";\n",
       "    render(\n",
       "      \"circuits-vis-8c91db10-74f4\",\n",
       "      Hello,\n",
       "      {\"name\": \"Neel\"}\n",
       "    )\n",
       "    </script>"
      ],
      "text/plain": [
       "<circuitsvis.utils.render.RenderedHTML at 0x13a9760d0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import circuitsvis as cv\n",
    "\n",
    "# Testing that the library works\n",
    "cv.examples.hello(\"Neel\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import stuff\n",
    "import torch\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "from transformer_lens import HookedEncoder, BertNextSentencePrediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x2a285a790>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT\n",
    "\n",
    "In this section, we will load a pretrained BERT model and use it for the Masked Language Modelling and Next Sentence Prediction tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Support for BERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n",
      "If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Moving model to device:  mps\n",
      "Loaded pretrained model bert-base-cased into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# NBVAL_IGNORE_OUTPUT\n",
    "bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Masked Language Modelling\n",
    "Use the \"[MASK]\" token to mask any tokens which you would like the model to predict.  \n",
    "When `return_type=\"predictions\"` is specified, the model's prediction is returned; alternatively (and by default), the function returns logits.  \n",
    "You can also pass `return_type=None`, in which case nothing is returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: The [MASK] is bright today.\n",
      "Prediction: \"sun\"\n"
     ]
    }
   ],
   "source": [
    "prompt = \"The [MASK] is bright today.\"\n",
    "\n",
    "prediction = bert(prompt, return_type=\"predictions\")\n",
    "\n",
    "print(f\"Prompt: {prompt}\")\n",
    "print(f'Prediction: \"{prediction}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also input a list of prompts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: ['The [MASK] is bright today.', 'She [MASK] to the store.', 'The dog [MASK] the ball.']\n",
      "Prediction: \"['Prediction 0: sun', 'Prediction 1: went', 'Prediction 2: caught']\"\n"
     ]
    }
   ],
   "source": [
    "# A batch of prompts, each containing one [MASK] token for the model to fill in\n",
    "prompts = [\n",
    "    \"The [MASK] is bright today.\",\n",
    "    \"She [MASK] to the store.\",\n",
    "    \"The dog [MASK] the ball.\",\n",
    "]\n",
    "\n",
    "predictions = bert(prompts, return_type=\"predictions\")\n",
    "\n",
    "print(f\"Prompt: {prompts}\")\n",
    "print(f'Prediction: \"{predictions}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Sentence Prediction\n",
    "To carry out Next Sentence Prediction, you have to use the class `BertNextSentencePrediction` and pass a `HookedEncoder` to its constructor.  \n",
    "Then, create a list containing the two sentences you want to perform NSP on and pass it as input to the forward function.  \n",
    "The model will then predict the probability that the sentence at position 1 follows (i.e. is the next sentence after) the sentence at position 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence A: A man walked into a grocery store.\n",
      "Sentence B: He bought an apple.\n",
      "Prediction: \"The sentences are sequential\"\n"
     ]
    }
   ],
   "source": [
    "nsp = BertNextSentencePrediction(bert)\n",
    "sentence_a = \"A man walked into a grocery store.\"\n",
    "sentence_b = \"He bought an apple.\"\n",
    "\n",
    "# Named sentence_pair rather than `input` so the builtin input() is not shadowed in the kernel\n",
    "sentence_pair = [sentence_a, sentence_b]\n",
    "\n",
    "# Kept separate from the `predictions` list produced by the masked-language-modelling cell above\n",
    "nsp_prediction = nsp(sentence_pair, return_type=\"predictions\")\n",
    "\n",
    "print(f\"Sentence A: {sentence_a}\")\n",
    "print(f\"Sentence B: {sentence_b}\")\n",
    "print(f'Prediction: \"{nsp_prediction}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inputting tokens directly\n",
    "You can also input tokens instead of a string or a list of strings into the model, which could look something like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: The [MASK] is bright today.\n",
      "Prediction: \"sun\"\n"
     ]
    }
   ],
   "source": [
    "prompt = \"The [MASK] is bright today.\"\n",
    "\n",
    "# Tokenize by hand and feed the token ids to the model; without a return_type we get the logits\n",
    "tokens = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
    "logits = bert(tokens)\n",
    "\n",
    "# Select the logits at the [MASK] position and decode the highest-scoring token id\n",
    "mask_positions = tokens == tokenizer.mask_token_id\n",
    "logprobs = logits[mask_positions].log_softmax(dim=-1)\n",
    "predicted_token_id = logprobs.argmax(dim=-1).item()\n",
    "prediction = tokenizer.decode(predicted_token_id)\n",
    "\n",
    "print(f\"Prompt: {prompt}\")\n",
    "print(f'Prediction: \"{prediction}\"')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Well done, BERT!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
